Feature/unswizzle #2732

Open

int-smart wants to merge 12 commits into NVIDIA:main from int-smart:feature/unswizzle

Conversation

@int-smart int-smart commented Mar 4, 2026

Description

This PR adds unswizzle support for scaling factors and extends the swizzle module so scaling tensors can be converted from GEMM-swizzled layout back to compact layout, including multi-tensor paths. It also adds round-trip and standalone tests to validate unswizzle correctness.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added unswizzle APIs and implementation in transformer_engine/common/swizzle/swizzle.cu and declarations in transformer_engine/common/include/transformer_engine/swizzle.h
  • Added multi-tensor unswizzle support with swizzle-like validation assumptions (homogeneous scaling mode/layout, swizzled input and compact output expectations)
  • Refactored multi-tensor unswizzle launch/kernels to mirror the swizzle structure (split row-wise and column-wise kernels) for better readability
  • Added/extended tests in tests/cpp/operator/test_swizzle.cu, including standalone unswizzle and swizzle→unswizzle round-trip coverage

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

int-smart and others added 6 commits March 3, 2026 20:40
- Introduced `nvte_unswizzle_scaling_factors` to convert swizzled scaling factors back to row-major format.
- Implemented `regs_unshuffle_with_bit_shifts` and `regs_unshuffle` for unshuffling operations in CUDA kernels.
- Added `unswizzle_row_scaling_kernel_impl` and `unswizzle_col_scaling_kernel_impl` for handling unswizzling in row and column scaling respectively.

These changes enhance the functionality of the swizzle module, enabling better handling of scaling factors in tensor operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
These enhancements test the changes introduced for unswizzling

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `compute_ref_unswizzle` to handle the conversion of swizzled scaling factors back to their original format.
- Added `performTestUnswizzle1D` to validate the unswizzling process with various scaling modes.
- Created `UnswizzleTestSuite` for comprehensive testing of unswizzling operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Moved the definition of `swizzle_row_scaling_kernel` to a new location for better organization.
- Ensured the kernel implementation is now properly defined and accessible for scaling operations in the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `multi_tensor_unswizzle_scaling_factors` to convert swizzled scaling factors back to their original row-major format.
- Implemented CUDA kernels for unswizzling in both row and column scaling, enhancing the swizzle module's functionality.
- Updated the launch function to handle multiple tensor unswizzling operations efficiently.

These changes improve the handling of scaling factors in tensor operations, ensuring better performance and organization within the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds unswizzle support for MXFP8 / NVFP4 scaling factors, implementing the inverse of the existing swizzle operation. It introduces new CUDA kernels (unswizzle_row_scaling_kernel_impl, unswizzle_col_scaling_kernel_impl), single-tensor and multi-tensor host-side dispatch functions, public C API declarations, and a comprehensive test suite including padded-shape standalone unswizzle tests and swizzle→unswizzle round-trip tests.

  • The kernel logic is a correct structural inverse of the swizzle kernels: shared-memory roles are swapped (linear load from swizzled global → SLM, then read via swizzle index vs. write via swizzle index → linear store to swizzled global), and regs_unshuffle / regs_unshuffle_with_bit_shifts are the correct inverses of their regs_shuffle* counterparts.
  • Output-size validation in unswizzle_scaling_factors correctly uses padded m * k from the swizzled input's shape, aligning with the multi-tensor path and the swizzle counterpart.
  • The test suite uses padded_dim0 * padded_dim1 for the unswizzle comparison (covering padding bytes), and the unswizzle data shapes include several padded-boundary cases (M-only, K-only, both).
  • One logic issue: unswizzle_scaling_factors (and multi_tensor_unswizzle_scaling_factors) rejects input tensors that carry both rowwise and columnwise scaling factors, but swizzle_scaling_factors processes both scale types in a single call. This asymmetry breaks the round-trip invariant for dual-scale MXFP8 tensors and is undocumented in the public header.

Confidence Score: 3/5

  • Mostly safe to merge, but the asymmetric dual-scale rejection should be resolved before this lands in a release used for round-trip workflows.
  • The kernel math and memory-layout inversions are correct and well-tested for single-scale tensors. The padding-aware test shapes cover the previously-reported validation bug. The remaining concern is an API contract issue: nvte_unswizzle_scaling_factors cannot round-trip a tensor that was produced by nvte_swizzle_scaling_factors if that tensor had both rowwise and columnwise scales — a configuration that is valid for MXFP8 dual-path training. This does not affect correctness for single-scale use cases, but it is a latent bug for callers who use the full MXFP8 dual-scale path.
  • transformer_engine/common/swizzle/swizzle.cu — the dual-scale rejection check in unswizzle_scaling_factors (lines 1163–1166) and the corresponding check in multi_tensor_unswizzle_scaling_factors (lines 1391–1392).

Important Files Changed

Filename Overview
transformer_engine/common/swizzle/swizzle.cu Adds unswizzle kernels (row + col), launch helpers, single-tensor unswizzle_scaling_factors, and multi-tensor multi_tensor_unswizzle_scaling_factors. The kernel logic correctly inverts the swizzle (shared-memory layout, regs_unshuffle / regs_unshuffle_with_bit_shifts order, and pointer offset formulas mirror the swizzle path). One logic issue found: unswizzle_scaling_factors errors on tensors that have both rowwise and columnwise scales, while swizzle_scaling_factors handles both simultaneously — breaking round-trip symmetry for dual-scale MXFP8 tensors.
transformer_engine/common/include/transformer_engine/swizzle.h Adds public declarations for nvte_unswizzle_scaling_factors and nvte_multi_tensor_unswizzle_scaling_factors with correct doxygen comments. Minor: the requirements section does not document that tensors with both rowwise and columnwise scales are rejected (a contract enforced in the implementation but not documented here).
tests/cpp/operator/test_swizzle.cu Adds compute_ref_unswizzle (correct inverse of compute_ref_swizzle), performTestUnswizzle1D with padded test shapes, performTestSwizzleUnswizzleRoundtrip, and two new test suites. The uninitialized-variable / missing-space issues from previous review threads are resolved by the separate !rowwise && !columnwise early return. The padded unswizzle test shapes cover M-only, K-only, and both-padded cases.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant nvte_unswizzle_scaling_factors
    participant unswizzle_scaling_factors
    participant unswizzle_row_scaling_kernel
    participant unswizzle_col_scaling_kernel

    Caller->>nvte_unswizzle_scaling_factors: (swizzled_tensor, output_tensor, stream)
    nvte_unswizzle_scaling_factors->>unswizzle_scaling_factors: dispatch
    unswizzle_scaling_factors->>unswizzle_scaling_factors: validate scaling_mode, swizzled flag,\nm%128==0, k%4==0, output size == m*k
    alt rowwise_unswizzle
        unswizzle_scaling_factors->>unswizzle_row_scaling_kernel: <<<(DIVUP(tiles_k,n_tb), tiles_m)>>>\n(swizzled_ptr, compact_ptr, m, k)
        note over unswizzle_row_scaling_kernel: 1) Linear load: swizzled global → SLM\n2) __syncthreads\n3) SLM tile → regs (swizzle index)\n4) regs_unshuffle<LType>\n5) Write regs → compact global
    else columnwise_unswizzle
        unswizzle_scaling_factors->>unswizzle_col_scaling_kernel: <<<(DIVUP(tiles_k,TB_DIM), DIVUP(tiles_m,vls))>>>\n(swizzled_ptr, compact_ptr, m, k)
        note over unswizzle_col_scaling_kernel: 1) Linear load: swizzled global → SLM\n2) __syncthreads\n3) SLM (swizzle index) → regs\n4) regs_unshuffle_with_bit_shifts\n5) Write regs → compact global (M-major)
    end
    unswizzle_scaling_factors-->>Caller: compact scale_inv written to output

Last reviewed commit: 4410e9d

@vthumbe1503 vthumbe1503 added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Mar 4, 2026
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
@int-smart int-smart force-pushed the feature/unswizzle branch from 85ea04b to 17dbb33 Compare March 5, 2026 02:13
int-smart and others added 2 commits March 4, 2026 18:49
ptrendx commented Mar 11, 2026

@int-smart Please address the comments from Greptile and ideally also add the test case with the input not already padded to 128,128.

@int-smart

@ptrendx Will look into these

@int-smart

@ptrendx From what I am understanding, then, padding is not relevant to the unswizzle kernel. Since the padding is already applied during the swizzling operation, I can just mirror it back to compact layout with the zero pads preserved, and that should do. Is that assumption correct? Initially I was thinking of removing the padding from the scale_inv itself, since this would be used for checkpointing.

int-smart and others added 2 commits March 12, 2026 19:53
- Updated unswizzling kernel implementations to remove original_M and original_K parameters, simplifying the function signatures.
- Enhanced test suite to utilize new unswizzling data shapes, ensuring comprehensive coverage of aligned and padded cases.

These changes improve the clarity and efficiency of the unswizzling process in the swizzle module.
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Comment on lines +1163 to +1166
const bool has_rowwise_scale_inv = input->scale_inv.has_data();
const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
"Input tensor has both row-wise and column-wise scaling factors");
Asymmetric handling of dual-scale tensors breaks round-trip correctness

unswizzle_scaling_factors explicitly rejects tensors that have both rowwise and columnwise scaling factors (line 1165–1166), but the counterpart swizzle_scaling_factors happily processes both scale types in a single call (it runs both the rowwise and columnwise swizzle paths sequentially).

This means calling the public round-trip pair —

nvte_swizzle_scaling_factors(input, swizzled, stream);   // succeeds: handles both scales
nvte_unswizzle_scaling_factors(swizzled, output, stream); // FAILS: "Input tensor has both..."

— will raise a runtime error for any MXFP8 tensor that carries both rowwise and columnwise scale factors (a common configuration in dual-path training).

The same asymmetry is present in the multi-tensor variant (multi_tensor_unswizzle_scaling_factors, line 1391–1392).

The fix is either:

  1. Support both scale types in the unswizzle path (mirror swizzle_scaling_factors), or
  2. Document the restriction in the header API comment so callers know to split the tensor or call two separate unswizzle invocations.

As-is, a user who relies on swizzle → unswizzle being a perfect inverse pair for the general case will encounter a silent API contract violation.
